#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import json
import os
import queue
import re
import shlex
import subprocess
import threading
from pysnmp.hlapi import *

import AutoExecUtils


class SnmpCollectWorker(threading.Thread):
    """Daemon worker thread that identifies devices over SNMP and collects them.

    Each worker takes object-info dicts from a shared queue, reads
    sysObjectID / sysDescr / sysName from the device, matches them against
    the ``_discovery_rule`` collection and, when a rule matches, runs the
    matching collector plugin followed by ``savedata``.  A ``None`` item on
    the queue tells the worker to exit.
    """

    def __init__(self, name, execQueue, snmpPort, db, pluginDir, communities, unknownItems):
        """Create the worker.

        Args:
            name: Thread name.
            execQueue: Queue of object-info dicts; ``None`` is the exit sentinel.
            snmpPort: UDP port used for the SNMP queries.
            db: Mapping-like database handle; only ``db['_discovery_rule']``
                is used (presumably a MongoDB collection -- it is queried
                through ``find(filter, projection)``).
            pluginDir: Directory that holds the collector and savedata tools.
            communities: Iterable of SNMP community strings to try.
            unknownItems: Shared list that receives every object that could
                not be identified or collected.
        """
        threading.Thread.__init__(self, name=name, daemon=True)
        self.goToStop = False
        self._queue = execQueue
        self.snmpPort = snmpPort
        self.ruleCollection = db['_discovery_rule']
        # Number of objects whose collector/savedata command failed or whose
        # processing raised; read by the code that joins the workers.
        self.errorCount = 0
        self.pluginDir = pluginDir
        self.communities = communities
        self.unknownItems = unknownItems

    def getObjCatByRule(self, objInfo):
        """Classify ``objInfo`` with the first matching discovery rule.

        Rules are looked up by ``SYS_OBJECT_ID``; a rule matches when it has
        no ``sysDescrPattern`` or when that pattern matches ``SYS_DESCR``.
        On a match ``_OBJ_CATEGORY``, ``_OBJ_TYPE``, ``VENDOR`` and ``MODEL``
        are copied from the rule into ``objInfo``.

        Returns:
            True when a rule matched, False otherwise.
        """
        sysObjId = objInfo.get('SYS_OBJECT_ID')
        # A device may not report sysDescr; match against '' instead of
        # handing None to re.match (which raises TypeError).
        sysDescr = objInfo.get('SYS_DESCR') or ''
        for rule in self.ruleCollection.find({'sysObjectId': sysObjId}, {'_id': False}):
            descrPattern = rule.get('sysDescrPattern')
            if descrPattern is None or re.match(descrPattern, sysDescr):
                objInfo['_OBJ_CATEGORY'] = rule['_OBJ_CATEGORY']
                objInfo['_OBJ_TYPE'] = rule['_OBJ_TYPE']
                objInfo['VENDOR'] = rule['VENDOR']
                objInfo['MODEL'] = rule['MODEL']
                return True

        return False

    def snmpGetSysInfo(self, objInfo, communities):
        """Query the device with each community until one identifies it.

        Fills ``SYS_OBJECT_ID``, ``SYS_DESCR`` and ``DEV_NAME`` in ``objInfo``
        from the SNMP answer and classifies the object via
        :meth:`getObjCatByRule`.

        Returns:
            The community that answered and led to a rule match, or None
            when no community did.
        """
        ip = objInfo.get('MGMT_IP')
        for community in communities:
            iterator = getCmd(
                SnmpEngine(),
                CommunityData(community),
                UdpTransportTarget((ip, self.snmpPort)),
                ContextData(),
                ObjectType(ObjectIdentity('1.3.6.1.2.1.1.2.0')),
                ObjectType(ObjectIdentity('1.3.6.1.2.1.1.1.0')),
                ObjectType(ObjectIdentity('1.3.6.1.2.1.1.5.0'))
            )

            errorIndication, errorStatus, errorIndex, varBinds = next(iterator)

            if errorIndication:
                print(errorIndication)
                continue

            if errorStatus:
                print('%s at %s' % (errorStatus.prettyPrint(),
                                    errorIndex and varBinds[int(errorIndex) - 1][0] or '?'))
                continue

            # Tracked per community so a failed attempt can never inherit the
            # success of an earlier one.
            snmpQuery = False
            for varBind in varBinds:
                oid = str(varBind[0])
                oidVal = str(varBind[1])
                if oid == '1.3.6.1.2.1.1.2.0':
                    objInfo['SYS_OBJECT_ID'] = '.' + oidVal
                    snmpQuery = True
                elif oid == '1.3.6.1.2.1.1.1.0':
                    objInfo['SYS_DESCR'] = oidVal
                elif oid == '1.3.6.1.2.1.1.5.0':
                    objInfo['DEV_NAME'] = oidVal

            if not snmpQuery:
                continue

            print("INFO: Snmp get for {} success.\n".format(ip), end='')
            if self.getObjCatByRule(objInfo):
                print("INFO: Device regconized rule matched for {}.\n".format(ip), end='')
                # Stop at the first working community; trying the remaining
                # ones would overwrite the result with an unverified value.
                return community

            print("INFO: There no device regconized rule matched for {}.\n".format(ip), end='')

        return None

    def getCollectCmd(self, objInfo, nodeName, community):
        """Build the shell command that collects ``objInfo``.

        Creates a working directory named after the management IP (relative
        to the current directory) when a collector exists for the object.

        Args:
            objInfo: Object-info dict with ``_OBJ_CATEGORY``, ``_OBJ_TYPE``
                and ``MGMT_IP``.
            nodeName: Node name passed on to the collector.
            community: SNMP community passed on as the node password.

        Returns:
            The command string, or None when no collector handles the
            category/type combination.
        """
        objCat = objInfo.get('_OBJ_CATEGORY')
        objType = objInfo.get('_OBJ_TYPE')

        nodeInfo = {
            'resourceId': 0,
            'nodeName': nodeName,
            'host': objInfo.get('MGMT_IP'),
            'port': None,
            'protocol': 'snmp',
            # Hand the collector the port that was actually probed.
            'protocolPort': self.snmpPort,
            'username': 'none',
            'password': community,
            'nodeType': objType
        }
        nodeJsonStr = json.dumps(nodeInfo, ensure_ascii=False)

        tool = None
        toolArgs = ''
        if objCat == 'SWITCH':
            tool = 'switchcollector'
            toolArgs = ' --objtype {}'.format(shlex.quote(str(objType)))
        elif objCat == 'LOADBALANCER':
            if objType == 'F5':
                tool = 'f5collector'
            elif objType == 'A10':
                tool = 'a10collector'
        elif objCat == 'SECDEV':
            if objType == 'FireWall':
                tool = 'firewallcollector'
                toolArgs = ' --type auto'
        elif objCat == 'FCDEV':
            # The type, not the category, selects the collector here.
            if objType is not None and objType.strip() == 'FCSwitch':
                tool = 'storagecollector'
                toolArgs = ' --type auto'

        if tool is None:
            return None

        ip = objInfo['MGMT_IP']
        os.makedirs(ip, exist_ok=True)

        # nodeName comes from the device's sysName, so every interpolated
        # value is shell-quoted before it reaches os.system().
        return "cd {} && {}{} --node {}".format(
            shlex.quote(ip),
            shlex.quote('{}/{}'.format(self.pluginDir, tool)),
            toolArgs,
            shlex.quote(nodeJsonStr)
        )

    def saveOneNode(self, ip):
        """Run ``savedata`` on ``<ip>/output.json`` and return os.system()'s status."""
        saveCmd = "{} --outputfile {}".format(
            shlex.quote('{}/savedata'.format(self.pluginDir)),
            shlex.quote('{}/output.json'.format(ip))
        )
        return os.system(saveCmd)

    def collect(self, objInfo):
        """Identify one object over SNMP, collect it and save the result.

        Objects that cannot be identified, have no collector, or whose
        collector fails are appended to ``unknownItems``.  ``errorCount`` is
        incremented when the collector or the savedata command returns a
        non-zero status.
        """
        # Detect the object type from SNMP sysObjectId, sysDescr and sysName.
        workCommunity = self.snmpGetSysInfo(objInfo, self.communities)
        if workCommunity is None:
            # SNMP could not read sysObjectId or no rule recognised the device.
            objInfo['_OBJ_CATEGORY'] = 'UNKNOWN'
            self.unknownItems.append(objInfo)
            return

        collectCmd = self.getCollectCmd(objInfo, objInfo.get('DEV_NAME'), workCommunity)
        if collectCmd is None:
            self.unknownItems.append(objInfo)
            return

        print("INFO: Collect information for " + objInfo['MGMT_IP'] + ".\n", end='')
        if os.system(collectCmd) != 0:
            self.errorCount += 1
            self.unknownItems.append(objInfo)
            return

        print("INFO: Object catetory:{} object type:{} IP:{} collected.\n".format(objInfo.get('_OBJ_CATEGORY'), objInfo.get('_OBJ_TYPE'), objInfo.get('MGMT_IP')), end='')
        if self.saveOneNode(objInfo.get('MGMT_IP')) != 0:
            self.errorCount += 1

    def run(self):
        """Process queue items until ``goToStop`` is set or a ``None`` arrives."""
        while not self.goToStop:
            objInfo = self._queue.get()
            if objInfo is None:
                break
            try:
                self.collect(objInfo)
            except Exception as ex:
                # Keep the worker alive: a dead worker stops draining the
                # queue and can leave the producer blocked on put().
                self.errorCount += 1
                print("ERROR: Collect for {} failed, {}\n".format(objInfo.get('MGMT_IP'), ex), end='')


class NmapScan:
    """Runs nmap and turns its grepable (``-oG``) output into per-host dicts."""

    # Ordered OS-description matchers; the first entry that matches wins.
    # An entry carries either a literal substring ('key') or a compiled regex
    # ('pattern').
    # NOTE(review): substring matching is case-sensitive; confirm that nmap
    # really reports vendors with this exact casing (e.g. 'CISCO').
    _TYPES_DEF = [
        {'key': 'Linux', '_OBJ_CATEGORY': 'OS', '_OBJ_TYPE': 'Linux'},
        {'key': 'Windows', '_OBJ_CATEGORY': 'OS', '_OBJ_TYPE': 'Windows'},
        {'key': 'AIX', '_OBJ_CATEGORY': 'OS', '_OBJ_TYPE': 'AIX'},
        {'key': 'SunOS', '_OBJ_CATEGORY': 'OS', '_OBJ_TYPE': 'SunOS'},
        {'key': 'FreeBSD', '_OBJ_CATEGORY': 'OS', '_OBJ_TYPE': 'FreeBSD'},
        {'key': 'CISCO', '_OBJ_CATEGORY': 'SWITCH', '_OBJ_TYPE': None},
        {'key': 'Huawei', '_OBJ_CATEGORY': 'SWITCH', '_OBJ_TYPE': None},
        {'key': 'Juniper', '_OBJ_CATEGORY': 'SWITCH', '_OBJ_TYPE': None},
        {'key': 'Ruijie', '_OBJ_CATEGORY': 'SWITCH', '_OBJ_TYPE': None},
        {'key': 'H3C', '_OBJ_CATEGORY': 'SWITCH', '_OBJ_TYPE': None},
        {'pattern': re.compile(r'Google Android'), '_OBJ_CATEGORY': 'MOBILE', '_OBJ_TYPE': 'Android'}
    ]

    def __init__(self):
        pass

    def parse(self, hostInfo):
        """Derive an object-info dict from a host's ``OS`` description.

        Args:
            hostInfo: Host dict as produced by :meth:`scan`; reads ``OS`` and
                ``IP``.

        Returns:
            A dict with ``_OBJ_CATEGORY``, ``_OBJ_TYPE`` and ``MGMT_IP``, or
            None when the host has no ``OS`` entry or nothing matches.
        """
        OSDesc = hostInfo.get('OS')
        if OSDesc is None:
            return None

        for matchItem in self._TYPES_DEF:
            key = matchItem.get('key')
            pattern = matchItem.get('pattern')

            if key is not None:
                hit = key in OSDesc
            elif pattern is not None:
                hit = pattern.search(OSDesc) is not None
            else:
                hit = False

            if hit:
                return {
                    '_OBJ_CATEGORY': matchItem['_OBJ_CATEGORY'],
                    '_OBJ_TYPE': matchItem['_OBJ_TYPE'],
                    'MGMT_IP': hostInfo['IP']
                }

        return None

    def _parseGrepLine(self, line, result):
        """Merge one tab-separated grepable-output line into ``result``.

        ``result`` maps IP -> host dict and is updated in place.  Fields that
        are not ``name: value`` are printed and skipped; fields that appear
        before any ``Host`` field on the line are ignored because there is
        no host to attach them to.
        """
        hostInfo = None
        for field in line.split('\t'):
            # partition splits on the first ': ' only, so values that contain
            # ': ' themselves stay intact.
            name, sep, val = field.partition(': ')
            if not sep:
                print(line)
                continue

            if name == 'Host':
                ip, _, dns = val.partition(' ')
                ip = ip.strip()
                hostInfo = result.setdefault(ip, {})
                hostInfo['IP'] = ip
                # The host name is wrapped in parentheses: "(name)".
                hostInfo['Name'] = dns[1:-1]
                continue

            if hostInfo is None:
                continue

            if name == 'Ports':
                portsInfo = hostInfo.setdefault('Ports', {})
                for portSeg in val.split('/, '):
                    # (port, state, protocol, owner, service, rpc_info, version)
                    portParts = portSeg.split('/')
                    if len(portParts) >= 7 and portParts[1] == 'open':
                        portsInfo[portParts[0]] = {
                            'port': portParts[0],
                            'state': portParts[1],
                            'protocol': portParts[2],
                            'owner': portParts[3],
                            'service': portParts[4],
                            'rpc_info': portParts[5],
                            'version': portParts[6]
                        }
            elif name == 'Protocols':
                protocolsInfo = hostInfo.setdefault('Protocols', {})
                for protocolSeg in val.split('/, '):
                    # (number, state, name)
                    protocolParts = protocolSeg.split('/')
                    if len(protocolParts) >= 3:
                        protocolsInfo[protocolParts[2]] = {
                            'number': protocolParts[0],
                            'state': protocolParts[1],
                            'name': protocolParts[2]
                        }
            elif name == 'Status':
                hostInfo['Status'] = val.strip()
            else:
                hostInfo[name] = val

    def scan(self, net=None, ports=None, timingTmpl=5):
        """Scan ``net`` with ``sudo nmap`` and return the parsed hosts.

        Args:
            net: Target specification handed to nmap; whitespace separates
                multiple targets.
            ports: Comma-separated port list; defaults to
                '22,161,135,139,445,3389'.
            timingTmpl: Value for nmap's ``-T`` option.

        Returns:
            Dict mapping IP -> host dict with ``IP``, ``Name`` and, when
            reported, ``Ports`` (open ports only, keyed by the port number as
            a string), ``Protocols``, ``Status`` and any other field of the
            grepable output under its own name.
        """
        if ports is None:
            ports = '22,161,135,139,445,3389'

        # Argument list instead of a shell string: nothing in net/ports can
        # be interpreted by a shell.
        nmapCmd = ['sudo', 'nmap', '-oG', '-']
        nmapCmd.extend(str(net).split())
        nmapCmd.extend(['-p', str(ports), '-T{}'.format(timingTmpl), '-O', '-n', '-sSU', '--top-ports', '5'])

        result = {}
        # The context manager closes the pipes and waits for the child.
        with subprocess.Popen(nmapCmd, close_fds=True, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) as child:
            # Iterate whole lines; a fixed-size readline() would cut a long
            # host line in two and corrupt the parse.
            for rawLine in child.stdout:
                line = rawLine.decode('utf-8', errors='replace').strip()
                if line == '' or line.startswith('#') or line.startswith('RTTVAR has grown to over'):
                    continue
                self._parseGrepLine(line, result)

        return result


def buildWorkerPool(workerCount, execQueue, snmpPort, db, pluginDir, communities, unknownItems):
    """Create and start ``workerCount`` SnmpCollectWorker daemon threads.

    Every worker is named ``Worker-<index>`` and receives the same queue,
    SNMP port, database handle, plugin directory, communities and
    unknown-items list.

    Returns:
        The list of started worker threads, in creation order.
    """
    workers = []
    for i in range(workerCount):
        worker = SnmpCollectWorker('Worker-{}'.format(i), execQueue, snmpPort, db, pluginDir, communities, unknownItems)
        # Thread.setDaemon() is deprecated; assign the attribute instead.
        worker.daemon = True
        worker.start()
        workers.append(worker)

    return workers


def discovery(snmpPort, workerCount, net, ports, timingTmpl, communities=None):
    """Scan one network, collect SNMP devices and save everything else.

    Hosts whose nmap OS description is recognised are handled as follows:
    category ``OS`` goes straight to the "other" list; hosts with the SNMP
    port reported open are queued for the SNMP workers; the remaining ones
    are marked ``UNKNOWN`` and added to the "other" list.  Hosts without any
    open-port information are skipped.  The "other" list (which the workers
    also extend with objects they could not handle) is written through
    ``AutoExecUtils.saveOutput`` and the ``savedata`` tool.

    Args:
        snmpPort: SNMP UDP port (int).
        workerCount: Number of SNMP worker threads.
        net: Target specification passed to nmap.
        ports: Comma-separated port list passed to nmap.
        timingTmpl: nmap ``-T`` timing template.
        communities: SNMP communities to try.  When omitted, the
            module-level ``communities`` set by the script entry point is
            used (falling back to ``['public']``).

    Returns:
        0 on success; otherwise the sum of the workers' error counts and the
        savedata status.
    """
    if communities is None:
        # Backward compatibility: callers used to rely on the module global.
        communities = globals().get('communities', ['public'])

    pluginDir = os.path.dirname(os.path.realpath(__file__))

    nmapScan = NmapScan()
    result = nmapScan.scan(net=net, ports=ports, timingTmpl=timingTmpl)

    (dbclient, db) = AutoExecUtils.getDB()

    otherItems = []

    execQueue = queue.Queue(workerCount * 2)
    workerThreads = buildWorkerPool(workerCount, execQueue, snmpPort, db, pluginDir, communities, otherItems)

    # The scan result keys ports by their textual number, so the integer
    # port has to be converted before the lookup.
    snmpPortKey = str(snmpPort)

    hasError = 0
    try:
        for hostInfo in result.values():
            objInfo = nmapScan.parse(hostInfo)
            if objInfo is None:
                continue

            portsInfo = hostInfo.get('Ports')
            if portsInfo is None:
                continue

            if objInfo.get('_OBJ_CATEGORY') == 'OS':
                otherItems.append(objInfo)
            elif portsInfo.get(snmpPortKey) is not None:
                execQueue.put(objInfo)
            else:
                objInfo['_OBJ_CATEGORY'] = 'UNKNOWN'
                otherItems.append(objInfo)
    finally:
        # Enqueue exactly one exit sentinel per worker thread.
        for _ in range(workerCount):
            execQueue.put(None)

        # Wait for every worker to exit; read its error count once, after it
        # has finished, so slow workers are not counted repeatedly.
        for worker in workerThreads:
            while worker.is_alive():
                worker.join(3)
            hasError = hasError + worker.errorCount

        if dbclient is not None:
            dbclient.close()

    if len(otherItems) > 0:
        out = {'DATA': otherItems}
        AutoExecUtils.saveOutput(out)
        saveCmd = "{}/savedata --outputfile output.json".format(pluginDir)
        ret = os.system(saveCmd)
        hasError = hasError + ret

    return hasError


if __name__ == "__main__":
    # Script entry point: parse the options, prepare the AutoExec environment
    # and run discovery once per network.  The names assigned here are module
    # globals; discovery() falls back to `communities` when it is not passed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--nets', default='192.168.0.0/24', help='Discovery nets, example:192.168.0.0/24,192.168.1.0/24')
    parser.add_argument('--ports', default='22,161,135,139,445,3389,3939', help='Scan ports, default:22,161,135,139,445,3389,3939')
    parser.add_argument('--snmpport', default=161, type=int, help='Snmp Port')
    parser.add_argument('--communities', default='["public"]', help='Snmp Communities(JSON Array), ecample:["public","mary"]')
    parser.add_argument('--workercount', default=16, type=int, help='Worker thread counts')
    parser.add_argument('--timingtmpl', default=4, type=int, help='Timing template, 1-5, 5 is fastest')

    args = parser.parse_args()

    snmpPort = args.snmpport
    workerCount = args.workercount
    timingTmpl = args.timingtmpl

    ports = args.ports
    try:
        communities = json.loads(args.communities)
    except ValueError as ex:
        parser.error('--communities is not valid JSON: {}'.format(ex))

    if isinstance(communities, str):
        # A bare JSON string would otherwise be iterated character by character.
        communities = [communities]
    elif not isinstance(communities, list):
        parser.error('--communities must be a JSON array of strings')

    autoexecHome = os.environ.get('AUTOEXEC_HOME')
    if autoexecHome:
        # Only extend PATH when the home is known; avoids adding "None/tools".
        pathParts = [part for part in (os.environ.get('PATH'), '%s/tools' % autoexecHome) if part]
        os.environ['PATH'] = ':'.join(pathParts)
    os.environ['OUTPUT_PATH'] = 'output.json'

    workPath = os.environ.get('AUTOEXEC_WORK_PATH')
    if not workPath:
        raise SystemExit('ERROR: Environment variable AUTOEXEC_WORK_PATH is not set.')
    os.chdir(workPath)

    failed = False
    for net in args.nets.split(','):
        if net.strip() == '':
            continue
        if discovery(snmpPort, workerCount, net, ports, timingTmpl) != 0:
            failed = True

    # Report failure through the exit status only after every net was scanned.
    if failed:
        raise SystemExit(1)
